import bluepysnap
import sys
import seaborn as sns
import pandas as pd
from scipy.special import expit
import scipy as sc
from scipy.optimize import curve_fit
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt


def get_spindle_amp(
    sim,
    filename=None,
    *,
    time_binsize=5,
    window_start=3500,
    window_length=1000,
    n_windows=16,
    min_peak_height=2,
    min_peak_width=3,
):
    """Measure spindle amplitude, duration and frequency in Rt_RC population activity.

    The spike report of the ``"thalamus_neurons"`` population, filtered to
    ``mtype == "Rt_RC"``, is binned into a per-node firing rate. Spike times
    are assumed to be in ms, as implied by the ``0.001`` and ``1000`` factors
    below. The rate trace is then cut into ``n_windows`` consecutive windows
    of ``window_length``, the first starting at ``window_start``. Peaks are
    detected in each window with :func:`scipy.signal.find_peaks`.

    Args:
        sim: A ``bluepysnap.Simulation``, or anything exposing the same
            ``spikes[...].spike_report.filter(...).report`` chain.
        filename: If not ``None``, the full rate trace is plotted and saved
            to this path.
        time_binsize: Width of the histogram bins, in report time units.
        window_start: Start time of the first analysis window.
        window_length: Length of each analysis window; also the spacing
            between consecutive window starts.
        n_windows: Number of analysis windows.
        min_peak_height: ``height`` argument passed to ``find_peaks``.
        min_peak_width: ``width`` argument (in bins) passed to ``find_peaks``.

    Returns:
        dict with keys ``"amps"``, ``"durations"`` and ``"freqs"``.
        - Each value is a list of length ``n_windows``.
        - ``amps`` is the mean rate at the detected peaks.
        - ``durations`` is the time from the first to the last peak.
        - ``freqs`` is the peak rate in Hz.
        - A window with at most two detected peaks gets ``0`` in all three
          lists.

    Raises:
        ValueError: If the filtered spike report contains no spikes.
    """
    from scipy.signal import find_peaks

    report = sim.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report

    times = report.index
    if times.size == 0:
        raise ValueError("Spike report for Rt_RC cells is empty; cannot compute spindle amplitude.")

    time_start = np.min(times)
    time_stop = np.max(times)
    # Regular bins starting at the first spike; the last spike time closes the final bin.
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, _ = np.histogram(times, bins=bins)
    bin_starts = bins[:-1]
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    # Per-node rate; 0.001 converts the bin width from ms to s, giving Hz.
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)

    if filename is not None:
        fig = plt.figure()
        plt.plot(bin_starts, freq)
        plt.savefig(filename)
        # Close explicitly: this function is called in a loop, and open figures
        # otherwise accumulate in memory.
        plt.close(fig)

    amps = []
    durations = []
    freqs = []
    for step in range(n_windows):
        low = window_start + step * window_length
        high = low + window_length
        # Both bounds are inclusive, so a bin starting exactly on a window
        # boundary is counted in both adjacent windows.
        mask = (bin_starts >= low) & (bin_starts <= high)
        f = freq[mask]
        t = bin_starts[mask]
        peaks = find_peaks(f, height=min_peak_height, width=min_peak_width)[0]
        if peaks.size <= 2:
            # Too few peaks in this window to count as a spindle.
            amps.append(0)
            durations.append(0)
            freqs.append(0)
            continue
        span = t[peaks[-1]] - t[peaks[0]]
        amps.append(np.mean(f[peaks]))
        durations.append(span)
        # (n_peaks - 1) inter-peak intervals over the span; *1000 turns per-ms into Hz.
        freqs.append((len(peaks) - 1) / span * 1000)
    return {"amps": amps, "durations": durations, "freqs": freqs}


def fit_func(_x, shift_exp, slope_exp=4, amp_exp=60):
    """Logistic curve with a fixed baseline of 2.

    Computes ``amp_exp * expit((_x - shift_exp) / slope_exp) + 2``. The
    independent variable comes first, so the function can be used as a
    ``scipy.optimize.curve_fit`` model.

    Args:
        _x: Independent variable (scalar or array).
        shift_exp: Midpoint of the sigmoid.
        slope_exp: Horizontal scale of the sigmoid; larger means a gentler rise.
        amp_exp: Height of the sigmoid above the baseline.

    Returns:
        The curve evaluated at ``_x``, with the same shape as ``_x``.
    """
    baseline = 2
    scaled = (_x - shift_exp) / slope_exp
    return amp_exp * expit(scaled) + baseline


if __name__ == "__main__":
    # Build, for each CT percentage, a 2D table of mean spindle amplitude over
    # (it2, iahp) parameter pairs and plot it as a heatmap.
    #
    # Pass --use-cache to reuse amp_2d_data_<perc>.csv from a previous run
    # instead of reloading every simulation. Without the flag everything is
    # recomputed. (The old code checked "_amp_2d_data_..." but read
    # "amp_2d_data_...", so the cache could never be hit.)
    use_cache = "--use-cache" in sys.argv[1:]

    base_path = Path("../../../simulations/06_09_Iahp_variants/")
    figures_dir = Path("figures")
    # get_spindle_amp saves per-model traces here; savefig fails if the folder is missing.
    figures_dir.mkdir(exist_ok=True)

    # (model name, variant folder prefix, (it2 value, iahp value)) for each simulated variant.
    variants = [
        ("Ecel-like_lowest-Iahp", "18-", (0.0005, 0.005)),
        ("Ecel-like_low-Iahp", "19-", (0.0005, 0.01)),
        ("Ecel", "1_", (0.0005, 0.015)),
        ("Ecel-like_med-Iahp", "20-", (0.0005, 0.02)),
        ("Ecel-like_high-Iahp", "21-", (0.0005, 0.025)),
        ("lowest-Iahp", "14_", (0.0007, 0.005)),
        ("low-Iahp", "15_", (0.0007, 0.01)),
        ("Spp", "2_", (0.0007, 0.015)),
        ("med-Iahp", "16_", (0.0007, 0.02)),
        ("high-Iahp", "17_", (0.0007, 0.025)),
        ("Runaway-like_lowest-Iahp", "22-", (0.001, 0.005)),
        ("Runaway-like_low-Iahp", "23-", (0.001, 0.01)),
        ("Runaway", "3_", (0.001, 0.015)),
        ("Runaway-like_med-Iahp", "24-", (0.001, 0.02)),
        ("Runaway-like_high-Iahp", "25-", (0.001, 0.025)),
    ]

    for perc in [80, 90, 100]:
        cache_file = Path(f"amp_2d_data_{perc}.csv")
        if use_cache and cache_file.exists():
            amp_data = pd.read_csv(cache_file, index_col=0)
        else:
            records = []
            for model, var, (it2, iahp) in variants:
                print(model)
                sim_path = (
                    base_path
                    / f"variant{var}{model}"
                    / f"variant{var}{model}_SUPER_LONG_{perc}percent_CT_up_down"
                    / "simulation_config.json"
                )
                sim = bluepysnap.Simulation(sim_path)
                data = get_spindle_amp(sim, filename=figures_dir / f"spindles_{model}_{perc}.pdf")
                records.append({"it2": it2, "iahp": iahp, "amp": np.mean(data["amps"])})
            # Rows are it2 values and columns are iahp values, both sorted ascending.
            amp_data = pd.DataFrame(records).pivot(index="it2", columns="iahp", values="amp")
            amp_data.to_csv(cache_file)

        amp_data.index.name = "it2"
        amp_data.columns.name = "iahp"
        print(amp_data)

        fig = plt.figure(figsize=(7, 3))
        sns.heatmap(
            # Reverse the rows so the largest it2 value is drawn at the top.
            amp_data.loc[::-1],
            cbar_kws={"shrink": 0.3, "label": "Spindle amp."},
            cmap="viridis",
            vmin=20,
            vmax=80,
        )
        plt.tight_layout()
        fig.savefig(f"2d_amp_heatmap_{perc}.pdf")
        plt.close(fig)
